package org.mat.framework.api.gateway.context;

import org.mat.framework.core.context.RequestContextConstants;
import org.mat.framework.core.context.RequestContextUtils;
import org.mat.framework.core.context.TraceIdAdaptor;
import org.mat.framework.core.context.TraceIdGenerator;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.lang3.StringUtils;
import org.slf4j.MDC;
import org.springframework.cloud.gateway.filter.GatewayFilterChain;
import org.springframework.cloud.gateway.filter.GlobalFilter;
import org.springframework.core.Ordered;
import org.springframework.http.server.reactive.ServerHttpRequest;
import org.springframework.util.CollectionUtils;
import org.springframework.web.server.ServerWebExchange;
import reactor.core.publisher.Mono;

import java.util.List;

/**
 * Global Spring Cloud Gateway filter that propagates an inbound trace id.
 *
 * <p>The trace id is read from the first value of the
 * {@link RequestContextConstants#TRACE_ID_KEY_FOR_HTTP_HEADER} request header. When it is present
 * and non-blank, this filter does three things:
 * <ul>
 *   <li>passes it to {@link RequestContextUtils#init(String)};</li>
 *   <li>stores it in the SLF4J {@link MDC} under {@link RequestContextConstants#TRACE_ID_KEY_FOR_MDC};</li>
 *   <li>sets it on the request forwarded down the filter chain.</li>
 * </ul>
 * When it is absent, no trace id is generated yet (see the TODO in {@link #filter}).
 *
 * <p>Note: the MDC is thread-local and WebFlux may switch threads while a request is processed.
 * The MDC value is therefore only reliable for code that runs on the thread invoking this filter.
 *
 * @ClassName: TracingFilter
 * @Date: 2021/8/26
 * @author: sunxinhe
 * @Version: 1.0
 */
@Slf4j
public class TracingFilter implements GlobalFilter, Ordered {

    /** Optional trace-id source consulted by {@link #checkTraceId(String)}; may be {@code null}. */
    private final TraceIdAdaptor traceIdAdaptor;

    /**
     * Creates the filter.
     *
     * @param traceIdAdaptor adaptor consulted for a trace id when the request carries none;
     *                       may be {@code null}
     */
    public TracingFilter(TraceIdAdaptor traceIdAdaptor) {
        this.traceIdAdaptor = traceIdAdaptor;
    }

    /**
     * Reads the trace id from the request header and, when present, publishes it to the request
     * context, the MDC and the forwarded request before continuing the chain.
     *
     * @param exchange the current server exchange
     * @param chain    the remaining gateway filter chain
     * @return completion signal of the downstream chain
     */
    @Override
    public Mono<Void> filter(ServerWebExchange exchange, GatewayFilterChain chain) {
        ServerHttpRequest request = exchange.getRequest();
        String path = request.getPath().toString();

        // First value of the trace-id header, or null when the header is missing or has no values.
        String traceId = request.getHeaders()
                .getFirst(RequestContextConstants.TRACE_ID_KEY_FOR_HTTP_HEADER);

        // TODO: guarantee a non-empty traceId. WebFlux switches threads several times while a
        // request is processed, so Spring Cloud Gateway cannot support TraceContext for now
        // (SkyWalking does not support it either). Revisit later.
        // traceId = checkTraceId(traceId);

        if (StringUtils.isNotBlank(traceId)) {
            RequestContextUtils.init(traceId);
            MDC.put(RequestContextConstants.TRACE_ID_KEY_FOR_MDC, traceId);
            request = request.mutate()
                    .header(RequestContextConstants.TRACE_ID_KEY_FOR_HTTP_HEADER, traceId)
                    .build();
        } else {
            // The MDC is thread-local and server threads are reused across requests. Without this,
            // a request without a trace id would be logged with a previous request's id.
            MDC.remove(RequestContextConstants.TRACE_ID_KEY_FOR_MDC);
        }

        // Logged after the MDC update so this line carries the current request's trace id (if any).
        log.info("进入了全局过滤器,path:[{}]", path);

        // TODO: write the trace id back to the response headers.
        return chain.filter(exchange.mutate().request(request).build());
    }

    /**
     * Returns a usable trace id. It tries, in order:
     * <ol>
     *   <li>the given value, if it is not blank;</li>
     *   <li>the value from {@link #traceIdAdaptor}, if the adaptor is set and returns a non-blank id;</li>
     *   <li>a freshly generated id from {@link TraceIdGenerator#generate()}.</li>
     * </ol>
     * Blankness is checked the same way as in {@link #filter}, so a whitespace-only header value
     * is treated as missing.
     *
     * @param traceId the trace id read from the request; may be {@code null} or blank
     * @return the resolved trace id
     */
    private String checkTraceId(String traceId) {
        // Fall back to the adaptor when no usable id was supplied.
        if (StringUtils.isBlank(traceId) && traceIdAdaptor != null) {
            traceId = traceIdAdaptor.getTraceId();
            log.info("未检查到traceId，从适配器获取：traceId = {}", traceId);
        }
        // Still nothing usable after the adaptor: generate one with the default generator.
        if (StringUtils.isBlank(traceId)) {
            traceId = TraceIdGenerator.generate();
            log.info("未检查到traceId，使用默认生成器生成：traceId = {}", traceId);
        }

        return traceId;
    }

    /**
     * Returns this filter's position in the global filter chain.
     *
     * @return {@code 0}
     */
    @Override
    public int getOrder() {
        return 0;
    }
}
